'''
Author: 梦付千秋星垂野 465943794@qq.com
Date: 2023-03-27 15:26:36
LastEditors: 梦付千秋星垂野 465943794@qq.com
LastEditTime: 2023-04-11 14:25:17
FilePath: /base_machinelearning/DigitRecognizer/train.py
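Description: Trains the digit-recognizer network (model.AlexNet.Net) on train.csv and saves its weights. (This header is still the koroFileHeader default template; set `customMade` to configure it: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE)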
'''
import pandas as pd
import torch
from utils.DataLoader import MNIST_data
from torchvision import transforms
from utils.DataLoader import RandomRotation,RandomShift
from model.AlexNet import Net
# from model.Net import Net
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
# ---------------------------------------------------------------------------
# Data, model and optimisation setup (executed at import time).
# ---------------------------------------------------------------------------

# Training CSV; it has a 'label' column, every other column is counted as a pixel.
TRAIN_CSV = '../datasets/digit-recognizer/train.csv'

train_df = pd.read_csv(TRAIN_CSV)
n_train = train_df.shape[0]
n_pixels = train_df.shape[1] - 1  # all columns except 'label'
n_class = len(set(train_df['label']))
batch_size = 64

# Augmentation + preprocessing applied to every training sample:
# random rotation (up to 20 degrees) and random shift (3), then conversion
# to a tensor normalised with mean 0.5 / std 0.5.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    RandomRotation(degrees=20),
    RandomShift(3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_dataset = MNIST_data(TRAIN_CSV, n_pixels=n_pixels, transform=train_transform)
# test_dataset = MNIST_data('../datasets/digit-recognizer/test.csv')

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                            batch_size=batch_size, shuffle=False)

model = Net()

optimizer = optim.Adam(model.parameters(), lr=0.003)

criterion = nn.CrossEntropyLoss()

# Multiply the learning rate by 0.8 after every 7 calls to exp_lr_scheduler.step().
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.8)

# Run on the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()
def train(epoch):
    """Run one training epoch over ``train_loader``, then step the LR scheduler.

    Operates on the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``exp_lr_scheduler``.

    Args:
        epoch: Epoch index; used only in the progress messages.
    """
    model.train()
    # Loop-invariant: decide once whether batches must be moved to the GPU.
    use_cuda = torch.cuda.is_available()

    for batch_idx, (data, target) in enumerate(train_loader):
        # Plain tensors take part in autograd directly; the deprecated
        # ``Variable`` wrapper is a no-op and is no longer used here.
        if use_cuda:
            data = data.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Report progress every 100 batches.
        if (batch_idx + 1) % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader),
                loss.item()))

    # One scheduler step per call, i.e. per epoch.
    exp_lr_scheduler.step()
def evaluate(data_loader):
    """Print the average cross-entropy loss and accuracy of ``model`` on a loader.

    Uses the module-level ``model``. The model is switched to eval mode and
    left in that mode when the function returns.

    Args:
        data_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (its length is the denominator for both metrics).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    use_cuda = torch.cuda.is_available()

    # ``volatile=True`` no longer has any effect; ``torch.no_grad()`` is what
    # actually stops autograd from recording a graph during evaluation.
    with torch.no_grad():
        for data, target in data_loader:
            if use_cuda:
                data = data.cuda()
                target = target.cuda()

            output = model(data)

            # Sum (not mean) per batch so dividing by the dataset size below
            # yields the true per-sample average even with a ragged last batch.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()

            pred = output.argmax(dim=1)
            # ``.item()`` keeps ``correct`` a plain Python int.
            correct += (pred == target.view_as(pred)).sum().item()

    n_samples = len(data_loader.dataset)
    loss = total_loss / n_samples

    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        loss, correct, n_samples,
        100. * correct / n_samples))
# Number of full passes over the training set.
n_epochs = 50

# After every epoch, report loss/accuracy on train_loader itself
# (the training data, not a held-out split).
for epoch in range(n_epochs):
    train(epoch)
    evaluate(train_loader)

# Persist only the learned weights (state_dict), not the whole module.
# NOTE(review): the file name says "epoch10" although n_epochs is 50 —
# confirm the intended name before relying on it.
checkpoint_path = 'epoch10-AlexNet.pth'
torch.save(model.state_dict(), checkpoint_path)